import os
import sys

sys.path.append(os.path.join(os.path.abspath(os.path.dirname(__file__)), '..'))

import torch
import numpy as np

from utils.exp_2D_cplex_op import single_quadratic_constraint_skip_solver
from utils.exp_4D_cplex_op import fast_quadratic_4D_skip_solver
from utils.write_op import write_pkl
from utils.read_op import read_pkl

def prun_skip_2D_single(
        array_list,
        s_list,
        t_list,
        quadratic_val_list,
        max_constraint,
        timelimit=None,
        path=None,
        debug=False):
    '''
    Solve (or load from cache) the u/v dicts for 2D skip-connection pruning
    under a single quadratic constraint.

    If ``path`` is usable and both ``path + 'u_dict.pkl'`` and
    ``path + 'v_dict.pkl'`` can be read, the cached dicts are returned.
    Otherwise ``single_quadratic_constraint_skip_solver`` is run. When
    ``path`` is usable, its results are then written to those two files.

    Args:
        array_list - list of Numpy 2D [nout, nin] size : nlist
        s_list - list of int size : N (forwarded as skip_s_list)
        t_list - list of int size : N (forwarded as skip_t_list)
        quadratic_val_list - list of int size : nlist
        max_constraint - float
        timelimit - float, forwarded to the solver
        path - string prefix for the cache files. It is concatenated directly
            with the file names, so a directory prefix must end with a path
            separator. None or the string 'None' disables caching.
        debug - bool; print intermediate values and enable solver debug output
    Return:
        u_dict, v_dict - dicts produced by the solver (or loaded from cache).
            For each layer index lidx, u_dict[lidx+1] and v_dict[lidx] are
            read below; their exact contents are defined by the solver.
    '''
    print("path:{}".format(path))
    if debug:
        print("array : {}".format(len(array_list)))
        print("quadratic_val_list : {}".format(quadratic_val_list))

    # Both None and the literal string 'None' (e.g. an unset CLI argument)
    # mean "no cache". Never concatenate None with a file name.
    use_cache = path is not None and path != 'None'

    cached = None
    if use_cache:
        try:
            cached = (read_pkl(path + 'u_dict.pkl'),
                      read_pkl(path + 'v_dict.pkl'))
        except OSError:
            # Missing or unreadable cache: fall through and recompute.
            cached = None

    if cached is not None:
        u_dict, v_dict = cached
    else:
        # NOTE(review): olamb/clamb/keep_idx values are passed through
        # unchanged; their meaning is defined by the solver.
        u_dict, v_dict = single_quadratic_constraint_skip_solver(
                array_list=array_list,
                quadratic_val_list=quadratic_val_list,
                skip_s_list=s_list,
                skip_t_list=t_list,
                max_constraint=max_constraint,
                olamb=10.0,
                clamb=10.0,
                keep_idx=[0],
                timelimit=timelimit,
                debug=debug)
        if use_cache:
            write_pkl(u_dict, path + 'u_dict.pkl')
            write_pkl(v_dict, path + 'v_dict.pkl')
        else:
            print("path is None")

    if debug:
        print("u_dict : {}".format(u_dict))
        print("v_dict : {}".format(v_dict))

    # Report the achieved constraint value:
    # sum over layers of sum(u_dict[lidx+1]) * sum(v_dict[lidx]) * q[lidx].
    cons = 0
    for lidx in range(len(array_list)):
        cons += (np.sum(u_dict[lidx + 1])
                 * np.sum(v_dict[lidx])
                 * quadratic_val_list[lidx])
    print("constraint : {}".format(cons))
    return u_dict, v_dict

def prun_skip_4D(
        array_list,
        s_list,
        t_list,
        quadratic_val_list,
        max_constraint,
        gamma,
        timelimit=None,
        path=None,
        debug=False):
    '''
    Solve (or load from cache) the u/v/q dicts for 4D skip-connection pruning.

    If ``path`` is usable and all of ``path + 'u_dict.pkl'``,
    ``path + 'v_dict.pkl'`` and ``path + 'q_dict.pkl'`` can be read, the
    cached dicts are returned. Otherwise ``fast_quadratic_4D_skip_solver`` is
    run, with ``single_quadratic_constraint_skip_solver`` as its ``solver``
    argument. When ``path`` is usable, its results are then written to those
    three files.

    Args:
        array_list - list of Numpy 2D [nout, nin] size : nlist
        s_list - list of int size : N (forwarded as skip_s_list)
        t_list - list of int size : N (forwarded as skip_t_list)
        quadratic_val_list - list of int size : nlist
        max_constraint - float
        gamma - float, forwarded to the solver
        timelimit - float, forwarded to the solver
        path - string prefix for the cache files. It is concatenated directly
            with the file names, so a directory prefix must end with a path
            separator. None or the string 'None' disables caching.
        debug - bool; print inputs and enable solver debug output
    Return:
        u_dict, v_dict, q_dict - the first three values returned by
            fast_quadratic_4D_skip_solver (or loaded from cache)
    '''
    if debug:
        print("array : {}".format(len(array_list)))
        print("flops_const_list : {}".format(quadratic_val_list))

    # Both None and the literal string 'None' (e.g. an unset CLI argument)
    # mean "no cache". Never concatenate None with a file name.
    use_cache = path is not None and path != 'None'

    cached = None
    if use_cache:
        try:
            cached = (read_pkl(path + 'u_dict.pkl'),
                      read_pkl(path + 'v_dict.pkl'),
                      read_pkl(path + 'q_dict.pkl'))
        except OSError:
            # Missing or unreadable cache: fall through and recompute.
            cached = None

    if cached is not None:
        u_dict, v_dict, q_dict = cached
    else:
        # The solver's fourth return value (a binary array list) is unused here.
        u_dict, v_dict, q_dict, _ = fast_quadratic_4D_skip_solver(
                array_list=array_list,
                quadratic_val_list=quadratic_val_list,
                skip_s_list=s_list,
                skip_t_list=t_list,
                max_constraint=max_constraint,
                solver=single_quadratic_constraint_skip_solver,
                gamma=gamma,
                timelimit=timelimit,
                debug=debug)
        if use_cache:
            write_pkl(u_dict, path + 'u_dict.pkl')
            write_pkl(v_dict, path + 'v_dict.pkl')
            write_pkl(q_dict, path + 'q_dict.pkl')
        else:
            print("path is None")
    return u_dict, v_dict, q_dict

